Implementing GPT2 in JAX for fun 🦀🦀🦀

Published: November 9, 2024

1 GPT2 for JAX 🚀

Explore the full project on the GitHub repository.

1.1 Context ✍️

This project involves rewriting XTTS in JAX to better understand its architecture and functionality. Originally developed by the now-defunct Coqui company, XTTS is a Text-to-Speech model. We'll recreate its generative component using a GPT2 architecture, a decoder-only transformer, based on (Radford et al. 2019). The implementation closely follows this tutorial.

The crux of the GPT2 architecture. Layers composed of masked attention and forwards.


1.2 GPT2 in Text-to-Speech

1.2.1 What are we building?

Our goal is to generate sequences of tokens for audio synthesis. Specifically, we aim to produce “audio tokens,” small units of audio, discovered using a VQVAE. By learning to map text tokens to audio tokens, the model becomes multi-modal.

The final output sequences represent speech, which we convert into audio using HiFiGAN. Additionally, we enhance speech expressiveness (e.g., tone, speed) by feeding 1024-dimensional vectors representing the target speaker’s paralinguistic features.

1.2.2 Under the Hood

Masked Attention

Masked attention is the core mechanism for learning relationships between tokens. It determines which tokens influence others by projecting them into smaller dimensions and computing relationships. Masking ensures the model focuses only on prior tokens, preventing it from “seeing” future ones.
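As a toy illustration of the masking (a sketch with uniform scores, not the model's real attention), note how each row of the soft-maxed scores only puts weight on its own and earlier positions:

```python
import jax
import jax.numpy as jnp

T = 4  # toy sequence length
scores = jnp.zeros((T, T))  # pretend every pair of tokens matches equally well
mask = jnp.tril(jnp.ones((T, T)))  # lower-triangular causal mask

# Positions above the diagonal get -inf, so softmax gives them weight 0
masked = jnp.where(mask == 0, -jnp.inf, scores)
weights = jax.nn.softmax(masked, axis=-1)
print(weights)
# Row i spreads its weight uniformly over positions 0..i:
# row 0 -> [1, 0, 0, 0], row 1 -> [0.5, 0.5, 0, 0], ...
```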

Studies classify attention patterns into:
1. Semantic: Tokens linked by meaning.
2. Linguistic: Tokens connected by grammar (e.g., verbs and nouns).
3. Rare Tokens: Infrequent but critical tokens.

Feedforward Layers

Feedforward layers mix outputs, add non-linearity via activation functions, and stack layers for hierarchical abstractions. The final output approximates a one-hot encoding in the token vocabulary, enabling token selection for sequential generation.
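As a toy picture of that last step (the vocabulary and logit values here are made up for illustration), the final scores are turned into probabilities and the most likely token is selected:

```python
import jax
import jax.numpy as jnp

vocab = ["the", "cat", "sat"]  # hypothetical 3-word vocabulary
logits = jnp.array([0.1, 3.2, -1.0])  # hypothetical final-layer output

probs = jax.nn.softmax(logits)  # approximates a one-hot over the vocabulary
next_token = int(jnp.argmax(probs))  # greedy token selection
print(vocab[next_token])  # "cat"
```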

1.3 Goal 🎯

Implement a GPT2 architecture using Equinox and train it on TinyStories.

2 Model

We have a few things to implement from the ground up: the custom activation function, the feedforward layer, and the masked attention. We then package these up in a nice layer that we can stack, and finally wrap the stacked layers into a GPT2!

We can start by importing our favorite libraries 🥰

import jax
import equinox as eqx
import equinox.nn as nn
import jax.numpy as jnp
import typing as tp

2.1 Configuration file

Because of the size of our model, we're going to be passing down lots of arguments. To avoid a long, unreadable list of parameters, we define a dataclass that lets us simply pass a single config object down to the model.

Feel free to experiment with various settings!

from dataclasses import dataclass


@dataclass
class GPTConfig:
    block_size: int = 100
    vocab_size: int = (
        50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    )
    n_layer: int = 16
    n_head: int = 12
    n_embd: int = 1024
    dropout: float = 0.1
    bias: bool = True  # use bias in the Linear and LayerNorm layers

2.2 SwiGLU Activation Function

We start by implementing the SwiGLU activation function, introduced in (Shazeer 2020), a powerful variant of GLU.

2.2.1 Why SwiGLU?

SwiGLU dynamically adjusts its activation based on the input. Think of it like a railway switch—redirecting the “activation path” when the input carries different information. This gives the network greater flexibility and control, leading to better performance.

For more details, see this explanation by Boudefel.

The function we implement, based on (Shazeer 2020)

The function we implement, based on (Shazeer 2020)

Below is a visualization of the Swish function, \(x \times \text{sigmoid}(x)\), which plays a role in SwiGLU:
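To get a feel for its shape, we can also evaluate Swish at a few points (a quick sketch): it is close to 0 for very negative inputs, exactly 0 at 0, and close to \(x\) for large positive inputs:

```python
import jax
import jax.numpy as jnp

xs = jnp.array([-5.0, -1.0, 0.0, 1.0, 5.0])
ys = jax.nn.swish(xs)  # x * sigmoid(x)
print(ys)  # ~[-0.03, -0.27, 0.0, 0.73, 4.97]
```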

class SwiGLU(eqx.Module):
    W: nn.Linear
    V: nn.Linear
    b: jax.Array
    c: jax.Array

    def __init__(self, input_dim, output_dim, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.W = nn.Linear(input_dim, output_dim, key=key1)
        self.V = nn.Linear(input_dim, output_dim, key=key2)
        self.b = jax.random.normal(key3, (output_dim,))
        self.c = jax.random.normal(key4, (output_dim,))

    @eqx.filter_jit
    def __call__(self, x):
        return jax.nn.swish(self.W(x) + self.b) * (self.V(x) + self.c)
Code
key = jax.random.PRNGKey(69)
mod = SwiGLU(10, 4, key)

x = jnp.ones(10)
print(mod(x).shape)

2.3 MLP

We can now move onto the multilayer perceptron, which we mentioned earlier as the feedforward part of our network. Because the model is big and we want to make sure it doesn't just "memorize" things, we include dropout, which pushes the model to avoid relying on individual neurons for information.

✨ You'll also notice that since our SwiGLU has two linear layers in it, each MLP we use actually contains 4 linear layers!

class MLP(eqx.Module):
    ff1: nn.Linear
    ff2: nn.Linear
    act: SwiGLU
    drop: nn.Dropout

    def __init__(self, config, key):

        key1, key2, key3 = jax.random.split(key, 3)

        self.ff1 = nn.Linear(
            config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=key1
        )
        self.act = SwiGLU(4 * config.n_embd, 4 * config.n_embd, key=key2)
        self.ff2 = nn.Linear(
            4 * config.n_embd, config.n_embd, use_bias=config.bias, key=key3
        )
        self.drop = nn.Dropout(config.dropout, deterministic=True)

    @eqx.filter_jit
    def __call__(self, x):
        y = self.ff1(x)
        y = self.act(y)
        y = self.ff2(y)
        return self.drop(y)

2.4 Masked attention

We now move onto one of the more complicated aspects of the model, although in the end it simply learns which tokens matter most to each other. There are plenty of fantastic tutorials out there for better understanding the underlying concept, notably: Transformers explained visually.

Again, our queries, keys and values are all linear projections of the tokens, which we use to find relationships between them:


Finding tokens that match, Query \(\times\) Key will give you the “closeness” (Jalammar 2019)
import math


class CausalSelfAttention(eqx.Module):
    attnk: nn.Linear
    attnq: nn.Linear
    attnv: nn.Linear
    proj: nn.Linear

    resid_dropout: nn.Dropout
    attn_dropout: nn.Dropout

    mask: jax.Array

    def __init__(self, config, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)

        self.attnk = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key1
        )
        self.attnv = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key2
        )
        self.attnq = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key3
        )
        self.attn_dropout = nn.Dropout(config.dropout, deterministic=True)
        self.resid_dropout = nn.Dropout(config.dropout, deterministic=True)

        self.proj = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key4
        )

        self.mask = jnp.tril(jnp.ones((config.block_size, config.block_size)))

    # Could play around with different attention-score calculations (Bahdanau?).
    # x is a sequence of embeddings; it should self-attend.

    @eqx.filter_jit
    def __call__(self, x):
        T, C = x.shape  # Seq length and embedding dim.

        q = jax.vmap(self.attnq)(x)
        k = jax.vmap(self.attnk)(x)
        v = jax.vmap(self.attnv)(x)

        att = jnp.matmul(q, jnp.transpose(k)) / math.sqrt(jnp.shape(k)[-1])
        att = jnp.where(
            jax.numpy.equal(jax.lax.stop_gradient(self.mask[:T, :T]), 0),
            float("-inf"),
            att,
        )
        att = jax.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att)

        y = jnp.matmul(att, v)

        y = jax.vmap(self.proj)(y)
        y = self.resid_dropout(y)
        return y

Small check…

Code
import optax


config = GPTConfig()
key = jax.random.PRNGKey(69)

mlp = CausalSelfAttention(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(eqx.filter(mlp, eqx.is_array))

x = jax.random.normal(jax.random.key(2), (30, config.n_embd))


@eqx.filter_jit
def calculate_loss(model, x, y):
    output = model(x)
    return jax.numpy.mean(jax.numpy.abs(y - output))


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(mlp, opt_state, x, x))

print(mlp(jax.random.normal(key, (100, config.n_embd))).shape)

2.5 Block

Ok! Now that we have the component parts of what we call a "block", we can assemble them. The block will then be stacked to get as many layers of abstraction as we wish. In our case we will stack it 16 times, as per the n_layer in the GPTConfig we defined.

class Block(eqx.Module):
    norm: nn.LayerNorm
    attn: CausalSelfAttention
    mlp: MLP

    def __init__(self, config, key):
        key1, key2 = jax.random.split(key, 2)

        self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.attn = CausalSelfAttention(config, key=key1)
        self.mlp = MLP(config, key=key2)

    @eqx.filter_jit
    def __call__(self, x):
        y = jax.vmap(self.norm)(x)
        y = self.attn(
            y
        )  # Can't vmap as the whole point is exchange info between tokens.
        x = y + x

        y = jax.vmap(self.norm)(x)
        y = jax.vmap(self.mlp)(y)
        x = y + x

        return x

We can run the same quick sanity check as before:

Code
import optax


config = GPTConfig()
key = jax.random.PRNGKey(69)

block = Block(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(eqx.filter(block, eqx.is_array))

x = jax.random.normal(jax.random.key(2), (30, config.n_embd))


@eqx.filter_jit
def calculate_loss(model, x, y):
    output = model(x)
    return jax.numpy.mean(jax.numpy.abs(y - output))


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(block, opt_state, x, x))

We can finally add the embeddings to our model, the maps that send tokens into the dimension the model works with, i.e. 1024 dims.

class GPT(eqx.Module):
    wte: nn.Embedding  # Token embeddings
    wpe: nn.Embedding  # Positional embeddings

    drop: nn.Dropout

    layers: list
    norm: nn.LayerNorm
    lm_head: nn.Linear

    def __init__(self, config, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd, key=key1)
        self.wpe = nn.Embedding(config.block_size, config.n_embd, key=key2)
        self.drop = nn.Dropout(config.dropout, deterministic=True)

        self.layers = [
            Block(config, key) for key in jax.random.split(key3, config.n_layer)
        ]
        self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, key=key4)

    @eqx.filter_jit
    def __call__(self, token_ids):
        (t,) = token_ids.shape

        # Should use better positional embeddings with cos and sin.
        pos = jnp.arange(0, t)
        tok_emb = jax.vmap(self.wte)(token_ids)
        pos_emb = jax.vmap(self.wpe)(pos)

        # Dropout at the first layer ? Seems a bit aggressive...
        x = self.drop(tok_emb + pos_emb)

        for block in self.layers:
            x = block(x)
        x = jax.vmap(self.norm)(x)
        logits = jax.vmap(self.lm_head)(x)
        # logits = jax.nn.softmax(logits)

        return logits
Code
import optax


config = GPTConfig()
key = jax.random.PRNGKey(69)

block = GPT(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(eqx.filter(block, eqx.is_array))

x = jax.numpy.ones((30, config.block_size), dtype=jax.numpy.int32)  # seq length must not exceed block_size (wpe only has block_size positions)


# @eqx.filter_jit
def calculate_loss(model, x, y):
    output = jax.vmap(model)(x)
    return jax.numpy.mean(
        jax.vmap(optax.softmax_cross_entropy_with_integer_labels)(output, y)
    )


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(block, opt_state, x, x))

3 Training

We can now move onto training the model! We're going to use the TinyStories (Eldan and Li 2023) dataset. Tiktoken maps the sentences to sequences of tokens that the model can understand. Below is the code to download the data, transform it into a binary file, and then serve it to our training loop through a simple batch loader.

Example Tinystories:

Once upon a time, in a small house, there lived a boy named Tim. Tim was a selfish boy. He did not like to share his things with others. One day, his mom brought home a big bag of vegetables. She wanted to...
Code
# saves the TinyStories dataset to a binary file for training. the following was helpful:
# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py

import os
from tqdm import tqdm
import numpy as np
import tiktoken
from datasets import load_dataset  # huggingface datasets

# number of workers in the .map() call
# a good number to use is roughly half the number of cpu cores
num_proc = 16

dataset = load_dataset("roneneldan/TinyStories")

# we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
enc = tiktoken.get_encoding("gpt2")


def process(example):
    ids = enc.encode_ordinary(
        example["text"]
    )  # encode_ordinary ignores any special tokens
    ids.append(enc.eot_token)  # add the end of text token, e.g. 50256 for gpt2 bpe
    # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
    out = {"ids": ids, "len": len(ids)}
    return out


# tokenize the dataset
tokenized = dataset.map(
    process,
    remove_columns=["text"],
    desc="tokenizing the splits",
    num_proc=num_proc,
)

# concatenate all the ids in each dataset into one large file we can use for training
os.makedirs("dataset", exist_ok=True)
for split, dset in tokenized.items():
    arr_len = np.sum(dset["len"])
    filename = os.path.join("dataset", f"{split}.bin")
    dtype = np.uint16  # (can do since enc.max_token_value == 50256 is < 2**16)
    arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
    total_batches = 1024

    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
        # Batch together samples for faster write
        batch = dset.shard(
            num_shards=total_batches, index=batch_idx, contiguous=True
        ).with_format("numpy")
        arr_batch = np.concatenate(batch["ids"])
        # Write into mmap
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)
    arr.flush()

We can now load the data from its binary representation into inputs and targets. Since we want the GPT to learn to predict the next token, we simply shift the input by 1!

Code
import os
import jax.numpy as jnp
import numpy

data_dir = "dataset"
config = GPTConfig()


def get_batch(split: str):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == "train":
        data = numpy.memmap(
            os.path.join(data_dir, "train.bin"), dtype=numpy.uint16, mode="r"
        )
    else:
        data = numpy.memmap(
            os.path.join(data_dir, "validation.bin"), dtype=numpy.uint16, mode="r"
        )

    ix = numpy.random.randint(len(data) - config.block_size, size=(32,))
    x = jnp.stack([jnp.array(data[i : i + config.block_size]) for i in ix])
    y = jnp.stack([jnp.array(data[i + 1 : i + 1 + config.block_size]) for i in ix])

    return x, y
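The shift-by-one is easiest to see on a toy array (the token ids here are made up):

```python
import numpy as np

data = np.array([10, 11, 12, 13, 14], dtype=np.uint16)  # pretend token stream
block_size = 3
i = 1  # a random offset, as drawn in get_batch

x = data[i : i + block_size]          # input:  [11, 12, 13]
y = data[i + 1 : i + 1 + block_size]  # target: [12, 13, 14] -- x shifted by one
print(x, y)
```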

We can now define our loss function. Our goal is to push the model to output something close to [0, 0, 0, …, 1, …, 0, 0], where the 1 sits at the index corresponding to the correct next token. optax, JAX's ML optimisation library, conveniently has a function for this.

import optax

learning_rate = 1e-4
warmup_iters = 3
init_from = "scratch"
lr_decay_iters = 20
iter_num = 0
min_lr = 1e-6

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == "scratch" else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)

optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler)


@eqx.filter_jit
def calculate_loss(model, x, y):
    output = jax.vmap(model)(x)
    return jax.numpy.mean(optax.softmax_cross_entropy_with_integer_labels(output, y))


def make_step(model, optimizer_state, x, y):
    loss, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, optimizer_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss

We can now move onto initializing our model and training it ! We can log the progress on wandb to see the loss curve.

Two runs compared


I trained two separate models; below are the parameters I changed:

Experiment ID   Learning Rate   Block size   Dropout   N_embed   Layers
Run1            1e-4            100          0.1       1024      16
Run2            1e-4            100          0.0       512       12
import tiktoken
import wandb

key = jax.random.PRNGKey(69)

gptconf = GPTConfig()
model = GPT(gptconf, key)

wandb.init(project="gpt-training", config=gptconf.__dict__)

optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
num_iterations = 1000

enc = tiktoken.get_encoding("gpt2")
for local_iter_num in range(num_iterations):
    x, y = get_batch("train")
    # Perform a single training step
    model, optimizer_state, loss = make_step(model, optimizer_state, x, y)

    wandb.log({"loss": loss, "iteration": local_iter_num})
    print(f"loss {loss}, iteration {local_iter_num}")


After training, we can save the model to a local directory and later reload it for inference. I quickly check whether the model produces gibberish or not:

eqx.tree_serialise_leaves("gpt2.eqx", model)

With the code below, we get a linguistically acceptable output:

Once upon a time, there was a little girl named Lily. She was so happy.
The little girl was so happy to the park
enc = tiktoken.get_encoding("gpt2")
# print(enc.special_tokens_set)
start = "Once upon"
x = jax.numpy.array([enc.encode(start)])

while x[0, -1] != enc.eot_token and x.shape[1] < config.block_size:
    logits = jax.vmap(model)(x)
    x = jax.numpy.concat(
        [x, jax.numpy.array([[jax.numpy.argmax(logits[0, -1])]])], axis=-1
    )
    print(enc.decode(jax.numpy.squeeze(x, axis=0)))

This concludes the training of a rudimentary GPT2 model. To go further, we could implement key-value caching, which removes redundant computation during generation. We could also implement beam search to find more optimal sequences of words, and use top-k and top-p sampling along with a temperature to add diversity. This is left as an exercise to the reader!
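As a sketch of the sampling idea (the function name and logit values here are my own, not part of the tutorial), top-k sampling with temperature restricts the choice to the k best logits and then samples from their softened distribution:

```python
import jax
import jax.numpy as jnp


def sample_top_k(logits, key, k=3, temperature=0.8):
    # Keep only the k highest logits, then sample among them with temperature.
    topk_vals, topk_idx = jax.lax.top_k(logits, k)
    choice = jax.random.categorical(key, topk_vals / temperature)
    return int(topk_idx[choice])


key = jax.random.PRNGKey(0)
logits = jnp.array([2.0, 1.0, 0.5, -1.0, 3.0, 0.0])  # made-up scores over 6 tokens
token = sample_top_k(logits, key)
print(token)  # one of the 3 best-scoring ids: 4, 0, or 1
```

Lower temperatures sharpen the distribution towards the argmax; higher ones add diversity.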

References

Eldan, Ronen, and Yuanzhi Li. 2023. “TinyStories: How Small Can Language Models Be and Still Speak Coherent English?” https://arxiv.org/abs/2305.07759.
Jalammar, Jay. 2019. "The Illustrated GPT-2 (Visualizing Transformer Language Models)." https://jalammar.github.io/illustrated-gpt2/.
Radford, Alec, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. 2019. “Language Models Are Unsupervised Multitask Learners.”
Shazeer, Noam. 2020. “GLU Variants Improve Transformer.” https://arxiv.org/abs/2002.05202.